library(tidyverse)
library(tidymodels)
library(workflows)
library(magrittr)
source('data_retrieval.R')
source('preprocess.R')
source('modeling.R')
source('data_retrieval.R')

# Drop any `roc_auc` object left over from a previous run of this script, so
# the name is free before it is re-assigned further down. Guarded with
# exists(): in a fresh session there is nothing to remove, and an unguarded
# rm() would raise an "object not found" warning.
if (exists("roc_auc", inherits = FALSE)) {
  rm(roc_auc)
}

# Admission-feature table: add `day_of_admission` (the Date part of
# `taarich_ishpuz`) and drop `taarich_shichrur`.
# NOTE(review): dbplyr::op_vars() is used on the result below, so this is
# presumably a lazy database table rather than a local data frame - confirm.
revision_clalit_adm <- get_table(
  con = con,
  schema = "schema_name",
  name = "Expenses_AdmissionFeatures"
) %>%
  mutate(day_of_admission = as.Date(taarich_ishpuz)) %>%
  select(-c("taarich_shichrur"))

# Names of the admission-derived columns: every column of the table above
# except the identifier / date columns listed here.
cols_admission_temp <- dbplyr::op_vars(revision_clalit_adm)
cols_admission <- setdiff(
  cols_admission_temp,
  c("id_var", "zihui_ishpuz_bikur", "index_date", "day_of_admission")
)

# Event table plus a numeric rank per event category (1 = urgent ... 6 = start).
# The rank is used further down to pick a single event per admission.
# Categories not listed fall through to the final `TRUE ~ NULL` branch.
all_events <- get_revision_cancer_events_data(con = con) %>%
  mutate(
    EVE_category_order = case_when(
      EVE_category == "urgent" ~ 1,
      EVE_category == "elective" ~ 2,
      EVE_category == "ER" ~ 3,
      EVE_category == "radiotherapy" ~ 4,
      EVE_category == "drugs" ~ 5,
      EVE_category == "start" ~ 6,
      TRUE ~ NULL
    )
  )

# Column names of the event table, excluding the helper rank column.
cols_features <- dbplyr::op_vars(
  select(all_events, -c("EVE_category_order"))
)

# Build one row per admission by attaching the matching event record.
# Steps, in order:
#   1. inner-join admissions to events on `id_var` and on
#      admission day == event `S_index_date_XX`;
#   2. keep only 'elective' / 'urgent' events;
#   3. within each (id_var, taarich_ishpuz) group, rank rows by
#      `EVE_category_order` and keep the first-ranked row (urgent = 1 is
#      ranked ahead of elective = 2);
#   4. drop helper / duplicate columns and rename the admission-side columns
#      to the event-side names (`S_index_date_XX`, `S_zihui_ishpuz_bikur_XX`).
# The result is still grouped (no ungroup() in this block) and still carries
# the helper column `rn`.
# NOTE(review): confirm `rn` is removed downstream, otherwise it reaches the
# model as a constant feature.
revision_clalit_adm_full <- revision_clalit_adm %>%
                          inner_join(all_events,
                                     by = c("id_var" ="id_var"  ,
                                            "day_of_admission"="S_index_date_XX" )) %>%
                          filter(EVE_category %in% c('elective','urgent')) %>% 
                          group_by(id_var,  taarich_ishpuz) %>% 
                          # rank within the group; ties are broken by row_number()
                          mutate(rn  = row_number(EVE_category_order)) %>% 
                          filter(rn == 1) %>% 
                          # NOTE(review): `id_var` is listed for removal here but is also a
                          # grouping variable; dplyr normally keeps grouping columns on
                          # select() - confirm the intended behaviour.
                          select (-c("EVE_category_order","id_var","index_date","S_zihui_ishpuz_bikur_XX")) %>% 
                          # renaming a column to its own name has no effect; looks like a leftover
                          rename("id_var" = "id_var") %>%  
                          rename("S_index_date_XX" = "day_of_admission") %>% 
                          # the event-side column of this name was dropped in the select() above
                          rename("S_zihui_ishpuz_bikur_XX" = "zihui_ishpuz_bikur")




# Fix the RNG state so the steps below are reproducible.
set.seed(42)

# Modelling dataset: add UTL features, drop the grouping, run the cleaning
# pipeline and convert every character column to a factor.
# The outer parentheses also print the result when run interactively.
(revision_cancer_clalit_adm <- revision_clalit_adm_full %>%
  add_UTL_features() %>%
  ungroup() %>%
  eol_cleaning_pipeline() %>%
  mutate(across(where(is.character), as.factor)))

# Join the values stored in "revision_p0.RData" and split the data.
# The result exposes `$train` and `$test`, both used further down.
# NOTE(review): 0.5 is presumably the split proportion - confirm in
# create_hashed_split().
(revision_cancer_clalit_adm_split <- revision_cancer_clalit_adm %>%
  join_phat0(file_name = "revision_p0.RData") %>%
  create_hashed_split(0.5))

# Both quantities are derived from the training part only, with
# DMG_died_within_365d as the outcome column.
calibration_proportion <- get_calibration_proportion(
  revision_cancer_clalit_adm_split$train,
  DMG_died_within_365d
)
downsampled_train <- downsample_data(
  revision_cancer_clalit_adm_split$train,
  DMG_died_within_365d
)

# Initial hyperparameter values handed to the Bayesian search below.
(initial_model_params <- list(
  tree_depth = 8,
  learn_rate = 0.123001965989824,
  loss_reduction = 11.1582974204794,
  min_n = 2.44744561146945,
  sample_size = 0.773196286475286,
  mtry = 0.549052911438048,
  trees = 400
))

# Hyperparameter search on the downsampled training data; keep the best
# candidate without its `.config` identifier column.
# NOTE(review): 50 is presumably the number of search iterations - confirm in
# run_bayes_optimisation().
tic <- Sys.time()
best_params <- run_bayes_optimisation(
  downsampled_train,
  DMG_died_within_365d,
  initial_model_params,
  50
) %>%
  select_best() %>%
  select(-.config)

print(paste0(
  "Hyperparamter search took ",
  difftime(Sys.time(), tic, units = "hours"),
  " hours!"
))

# Final model specification: the best parameters spliced into the xgboost spec.
best_model_spec <- get_xgboost_spec() %>%
  set_args(nthread = 40, !!!best_params)

# `toc` marks the start of the training timer.
toc <- Sys.time()
(xgboost_fitted <- downsampled_train %>%
  train_model(best_model_spec, outcome = DMG_died_within_365d))

print(paste0(
  "Model training took ",
  difftime(Sys.time(), toc, units = "hours"),
  " hours!"
))

# Score the held-out test part with the fitted workflow.
preds_test <- get_model_preds(
  model = xgboost_fitted,
  data = revision_cancer_clalit_adm_split$test,
  true_outcome = "DMG_died_within_365d"
)

# Calibrate the predictions using the proportion derived from the training data.
preds_test_calibrate <- calibrate_preds(
  preds_test,
  calibration_proportion,
  .pred_1
)

# Test-set AUC. The result deliberately re-uses the name `roc_auc`;
# it is removed again at the end of this section.
roc_auc <- preds_test_calibrate %>%
  roc_auc(truth = true_value, .pred_1)

feature_importance <- get_feature_importance(xgboost_fitted)
fit_xgboost <- pull_workflow_fit(xgboost_fitted)

# Bundle everything needed later and persist it to disk.
cancer_revision_clalit_all_model <- list(
  fit_xgboost = fit_xgboost,
  preds_test = preds_test_calibrate,
  roc_auc = roc_auc,
  feature_importance = feature_importance
)

save(
  cancer_revision_clalit_all_model,
  file = "revision_all_model.RData"
)

rm(roc_auc)

##### train AUC :

# Reload the model bundle written by the save() call above. The file name must
# match the one passed to save() ("revision_all_model.RData"); the previous
# name pointed at a file this script never writes.
initial_predict <- get_load_new(file_name = "revision_all_model.RData")

# Re-create the preprocessed training matrix from the downsampled training data.
recipe <- make_matrix_recipe(downsampled_train, DMG_died_within_365d) %>%
  prep()
train_data_preprocessed <- recipe %>% juice()

# Class probabilities on the training data (outcome column removed first).
preds <- predict(
  initial_predict$fit_xgboost,
  train_data_preprocessed %>% select(-DMG_died_within_365d),
  type = "prob"
)

preds_train_calibrate <- calibrate_preds(preds, calibration_proportion, .pred_1)

# NOTE(review): the estimate column here is `preds_after_bayes`, whereas the
# test-set AUC uses `.pred_1` - confirm which column calibrate_preds() returns.
roc_auc_train <- roc_auc(
  preds_train_calibrate,
  truth = train_data_preprocessed$DMG_died_within_365d,
  preds_after_bayes
)

# One row per split (test / train), written out as CSV.
rbind(
  data.frame(initial_predict$roc_auc) %>% mutate(split = "test"),
  data.frame(roc_auc_train) %>% mutate(split = "train")
) %>%
  mutate(sample = "clalit admissions all features") %>%
  write_csv("revision_all_model_AUC.csv")





### without admission:  -----------------------

# Feature columns matching "BT.*val" or "COV_bp", except
# "COV_bp_days_from_last_exam", which is kept.
EMR_to_drop <- setdiff(
  grep("BT.*val|COV_bp", cols_features, value = TRUE),
  "COV_bp_days_from_last_exam"
)

# Same dataset as the full model, minus the admission-derived columns and the
# EMR columns selected above. all_of() states explicitly that these are
# external character vectors of column names, not columns of the data.
# The outer parentheses also print the result when run interactively.
(revision_cancer_clalit_adm_split_noadm <- revision_cancer_clalit_adm %>%
  select(-all_of(c(cols_admission, EMR_to_drop))) %>%
  join_phat0(file_name = "revision_p0.RData") %>%
  create_hashed_split(0.5))

# Both quantities are derived from the training part of the reduced dataset
# and overwrite the values computed for the full model.
calibration_proportion <- get_calibration_proportion(
  revision_cancer_clalit_adm_split_noadm$train,
  DMG_died_within_365d
)
downsampled_train <- downsample_data(
  revision_cancer_clalit_adm_split_noadm$train,
  DMG_died_within_365d
)

# Initial hyperparameter values handed to the Bayesian search below.
(initial_model_params <- list(
  tree_depth = 8,
  learn_rate = 0.123001965989824,
  loss_reduction = 11.1582974204794,
  min_n = 2.44744561146945,
  sample_size = 0.773196286475286,
  mtry = 0.549052911438048,
  trees = 400
))

# Hyperparameter search on the current `downsampled_train`; keep the best
# candidate without its `.config` identifier column.
tic <- Sys.time()
best_params <- run_bayes_optimisation(
  downsampled_train,
  DMG_died_within_365d,
  initial_model_params,
  50
) %>%
  select_best() %>%
  select(-.config)

print(paste0(
  "Hyperparamter search took ",
  difftime(Sys.time(), tic, units = "hours"),
  " hours!"
))

# Final model specification: the best parameters spliced into the xgboost spec.
best_model_spec <- get_xgboost_spec() %>%
  set_args(nthread = 40, !!!best_params)

# `toc` marks the start of the training timer.
toc <- Sys.time()
(xgboost_fitted <- downsampled_train %>%
  train_model(best_model_spec, outcome = DMG_died_within_365d))

print(paste0(
  "Model training took ",
  difftime(Sys.time(), toc, units = "hours"),
  " hours!"
))

# Score the held-out test part of the reduced dataset with the fitted workflow.
preds_test <- get_model_preds(
  model = xgboost_fitted,
  data = revision_cancer_clalit_adm_split_noadm$test,
  true_outcome = "DMG_died_within_365d"
)

# Calibrate the predictions using the proportion derived from the training data.
preds_test_calibrate <- calibrate_preds(
  preds_test,
  calibration_proportion,
  .pred_1
)

# Test-set AUC; the result is stored under the name `roc_auc`.
roc_auc <- preds_test_calibrate %>%
  roc_auc(truth = true_value, .pred_1)

feature_importance <- get_feature_importance(xgboost_fitted)
fit_xgboost <- pull_workflow_fit(xgboost_fitted)

# Bundle everything needed later and persist it to disk.
cancer_revision_clalit_all_model_noadm <- list(
  fit_xgboost = fit_xgboost,
  preds_test = preds_test_calibrate,
  roc_auc = roc_auc,
  feature_importance = feature_importance
)

save(
  cancer_revision_clalit_all_model_noadm,
  file = "revision_all_model_noadm.RData"
)





##### train AUC :

# Reload the model bundle written by the save() call just above.
initial_predict <- get_load_new(file_name = "revision_all_model_noadm.RData")

# Re-create the preprocessed training matrix from the downsampled training data.
recipe <- prep(make_matrix_recipe(downsampled_train, DMG_died_within_365d))
train_data_preprocessed <- juice(recipe)

# Class probabilities on the training data (outcome column removed first).
preds <- predict(
  initial_predict$fit_xgboost,
  select(train_data_preprocessed, -DMG_died_within_365d),
  type = "prob"
)

preds_train_calibrate <- calibrate_preds(preds, calibration_proportion, .pred_1)

# NOTE(review): the estimate column here is `preds_after_bayes`, whereas the
# test-set AUC uses `.pred_1` - confirm which column calibrate_preds() returns.
roc_auc_train <- roc_auc(
  preds_train_calibrate,
  truth = train_data_preprocessed$DMG_died_within_365d,
  preds_after_bayes
)

# One row per split (test / train), written out as CSV.
rbind(
  data.frame(initial_predict$roc_auc) %>% mutate(split = "test"),
  data.frame(roc_auc_train) %>% mutate(split = "train")
) %>%
  mutate(sample = "clalit admissions No EMR") %>%
  write_csv("revision_noEMR_model_AUC.csv")







